from stable_baselines3.common.policies import MultiInputActorCriticPolicy, register_policy


def get_cutoff_entry(env):
    """Return the total flattened size of the critic-only observation entries.

    One observation is sampled from ``env.observation_space``, converted to
    tensors, and every entry whose key has the prefix ``"critic"`` (the text
    before the first ``":"``) is run through ``preprocess_obs``. The lengths
    of the flattened results are added up.

    :param env: environment whose ``observation_space`` can be sampled into a
        mapping and indexed by the same keys.
    :return: summed element count of the preprocessed critic entries; ``0``
        when no key carries the ``"critic"`` prefix.
    """
    from stable_baselines3.common.preprocessing import preprocess_obs
    from stable_baselines3.common.utils import get_device, obs_as_tensor

    sample = env.observation_space.sample()
    sample_tensors = obs_as_tensor(sample, get_device())

    total = 0
    for name in sample:
        prefix, _, _ = name.partition(":")
        if prefix != "critic":
            continue
        # Measure the size after preprocessing, since preprocessing may change
        # the number of elements relative to the raw sample.
        processed = preprocess_obs(sample_tensors[name], env.observation_space[name])
        total = total + processed.flatten().shape[0]
    return total


def get_custom_training_algorithm(algorithm, env, n_steps=None):
    """Build a training algorithm instance that uses ``CustomPolicy``.

    :param algorithm: name of the algorithm to build; only ``"A2C"`` is
        supported.
    :param env: environment handed to the algorithm and used to derive the
        ``cutoff_entry`` policy argument.
    :param n_steps: when given, the model is created with this ``n_steps``
        and ``gamma=1``; when ``None``, the library defaults are used for both.
    :return: the constructed ``A2C`` model.
    :raises ValueError: if ``algorithm`` is not a supported name.
    """
    if algorithm != "A2C":
        # Previously this fell through to ``return m`` with ``m`` unbound and
        # surfaced as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported algorithm {algorithm!r}; only 'A2C' is available."
        )

    from stable_baselines3 import A2C

    # Register the custom policy under a name derived from the class identity
    # so the model can look it up by string.
    policy_name = f"CustomPolicy_{id(CustomPolicy)}"
    register_policy(policy_name, CustomPolicy)

    # cutoff_entry is derived from the environment's observation space
    # (see get_cutoff_entry) rather than being hard-coded.
    policy_kwargs = {"cutoff_entry": get_cutoff_entry(env)}

    if n_steps is None:
        return A2C(env=env, policy=policy_name, policy_kwargs=policy_kwargs)

    return A2C(
        env=env,
        policy=policy_name,
        gamma=1,
        n_steps=n_steps,
        policy_kwargs=policy_kwargs,
    )


class CustomPolicy(MultiInputActorCriticPolicy):
    """Multi-input actor-critic policy with a fixed network architecture.

    The architecture passed to the parent class is a single dict with two
    64-unit layers for the policy head (``pi``), two 64-unit layers for the
    value head (``vf``), and the ``cutoff_entry`` value.

    NOTE(review): how ``cutoff_entry`` inside ``net_arch`` is consumed is not
    visible in this file; presumably a customised network extractor reads it.
    Confirm against the installed library.
    """

    def __init__(self, *args, **kwargs):
        """Create the policy.

        :param args: positional arguments forwarded unchanged to the parent.
        :param kwargs: keyword arguments forwarded to the parent, except for
            the required ``cutoff_entry`` entry, which is removed and placed
            into ``net_arch``.
        :raises KeyError: if ``cutoff_entry`` is not supplied.
        """
        # The parent constructor does not receive cutoff_entry as a keyword,
        # so take it out of kwargs before forwarding them.
        cutoff_entry = kwargs.pop("cutoff_entry")

        super().__init__(
            *args,
            **kwargs,
            net_arch=[dict(pi=[64, 64], vf=[64, 64], cutoff_entry=cutoff_entry)],
        )
